Fix WeightAveraging swapping in the un-updated average model during validation#21732

Open
ATOM00blue wants to merge 2 commits into Lightning-AI:master from ATOM00blue:fix/weight-averaging-validation-swap-before-update

@ATOM00blue

What does this PR do?

Fixes #21724

WeightAveraging (and its EMAWeightAveraging subclass) creates the AveragedModel in setup() as a copy of the model's initial weights, with n_averaged == 0. The validation hooks on_validation_epoch_start / on_validation_epoch_end swapped this average model in unconditionally whenever it existed.

When using a delayed start (e.g. EMAWeightAveraging(update_starting_at_step=1000)) and validating during the warmup period, the average model has never been updated, so validation ran against the frozen initial (untrained) weights instead of the current trained ones. This is what the issue describes as metrics being near zero before update_starting_at_step.

This PR only swaps the models for validation once the average model has actually been updated at least once (n_averaged > 0). The swap remains balanced across the start/end hooks because n_averaged does not change during validation.
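
The guard described above can be sketched with plain Python stand-ins. This is a minimal illustration, not the code in weight_averaging.py: ToyAverage, ToyCallback, and weights_seen_at_validation are hypothetical names, and lists of floats stand in for model parameters.

```python
class ToyAverage:
    """Hypothetical stand-in for AveragedModel: a weight copy plus an update counter."""

    def __init__(self, weights):
        self.weights = list(weights)  # snapshot of the initial weights, as in setup()
        self.n_averaged = 0           # stays 0 until the first update

    def update(self, weights):
        # Running mean over the updates seen so far; the first update copies the weights.
        n = self.n_averaged
        self.weights = [(a * n + w) / (n + 1) for a, w in zip(self.weights, weights)]
        self.n_averaged += 1


class ToyCallback:
    """Hypothetical, simplified validation hooks with the n_averaged guard."""

    def __init__(self, model_weights):
        self.model_weights = list(model_weights)
        self.average = ToyAverage(model_weights)

    def _swap(self):
        self.model_weights, self.average.weights = self.average.weights, self.model_weights

    def on_validation_epoch_start(self):
        # Swap in the average only if it has been updated at least once.
        if self.average.n_averaged > 0:
            self._swap()

    def on_validation_epoch_end(self):
        # Same condition as the start hook, so every swap is undone.
        if self.average.n_averaged > 0:
            self._swap()


def weights_seen_at_validation(cb):
    """Run one validation epoch and return the weights it would evaluate."""
    cb.on_validation_epoch_start()
    seen = list(cb.model_weights)
    cb.on_validation_epoch_end()
    return seen
```

During warmup the toy validation sees the trained weights; after the first update it sees the averaged ones, and the training weights are restored afterwards in both cases.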

Tests

  • Added test_weight_averaging_no_swap_before_first_update, which verifies that when the delayed-start threshold is never reached, the parameters seen at validation are the current trained weights, not the frozen initial snapshot. It fails before this change and passes after.
  • Updated SWATestCallback swap-count expectations: with a delayed update schedule, validation now only swaps once the average model has been updated.
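
To illustrate why the swap-count expectations change, here is a small counting sketch. It assumes a hypothetical schedule in which validation runs once after every training epoch and the average is updated at the end of each epoch from first_update_epoch onward; the function name and the numbers are illustrative and are not taken from SWATestCallback.

```python
def count_validation_swaps(num_epochs, first_update_epoch):
    """Count guarded swaps over a run; first_update_epoch=None means never updated."""
    swaps = 0
    n_averaged = 0
    for epoch in range(num_epochs):
        if first_update_epoch is not None and epoch >= first_update_epoch:
            n_averaged += 1  # the average is updated at the end of this training epoch
        # Validation: one guarded swap at the start hook and one at the end hook.
        for _hook in ("start", "end"):
            if n_averaged > 0:
                swaps += 1
    return swaps
```

Under these assumptions the count is always even, since both hooks see the same n_averaged, and it drops to zero when the threshold is never reached.
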
Before submitting
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG?

PR review

Anyone in the community is welcome to review the PR.

…ation

Before its first update, the AveragedModel only holds the copy of the
initial weights made in setup(). The validation hooks swapped it in
unconditionally, so during a delayed-start warmup (e.g. EMAWeightAveraging
with update_starting_at_step) validation evaluated the untrained snapshot
instead of the current trained weights.

Only swap the models when the average model has been updated at least once
(n_averaged > 0). The swap stays balanced across validation start/end since
n_averaged does not change during validation.
@ATOM00blue ATOM00blue requested a review from tchaton as a code owner May 22, 2026 02:47
Copilot AI review requested due to automatic review settings May 22, 2026 02:47
Contributor

Copilot AI left a comment

Pull request overview

Fixes a bug in the WeightAveraging / EMAWeightAveraging callbacks where validation could swap in the averaged model even before it had received its first update (n_averaged == 0), causing validation to run on an untrained initial-weight snapshot during delayed-start warmup.

Changes:

  • Guarded validation-time model swapping so it only happens after the averaged model has been updated at least once (n_averaged > 0).
  • Updated SWA test expectations to reflect that validation swapping now starts only after the first averaging update occurs.
  • Added a regression test ensuring validation observes current trained weights (not the frozen initial snapshot) when the averaging update threshold is never reached.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| src/lightning/pytorch/callbacks/weight_averaging.py | Prevents swapping in the averaged model for validation until n_averaged > 0, avoiding evaluation on an untrained snapshot. |
| tests/tests_pytorch/callbacks/test_weight_averaging.py | Adjusts swap-count expectations and adds a regression test for “no swap before first update” behavior. |
| src/lightning/pytorch/CHANGELOG.md | Documents the fix under the unreleased “Fixed” section. |

@Martovark

I think it would be better to use self._latest_update_step > 0 from WeightAveraging instead of self._average_model.n_averaged > 0: n_averaged most likely lives on the CUDA device, so checking it in a condition requires a device-to-host sync.
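
The trade-off in this suggestion can be sketched without a GPU. In the toy below, FakeDeviceCounter is a hypothetical stand-in for a counter tensor stored on an accelerator: each comparison increments syncs to mimic the device-to-host read that a truth test on such a tensor needs. ToyState and its method names are also hypothetical, and whether the two conditions agree for every update schedule in the real callback is an assumption here, not something the thread confirms.

```python
class FakeDeviceCounter:
    """Hypothetical stand-in for a counter tensor that lives on an accelerator."""

    def __init__(self):
        self.value = 0
        self.syncs = 0  # how many times the value was read back to the host

    def __gt__(self, other):
        self.syncs += 1  # mimics the device-to-host copy behind `if tensor > 0`
        return self.value > other


class ToyState:
    """Hypothetical state holding both a device-side and a host-side counter."""

    def __init__(self):
        self.n_averaged = FakeDeviceCounter()
        self.latest_update_step = 0  # plain Python int kept on the host

    def record_update(self, step):
        self.n_averaged.value += 1
        self.latest_update_step = step

    def should_swap_device(self):
        # Guard as written in the PR: reads the device-side counter.
        return self.n_averaged > 0

    def should_swap_host(self):
        # Guard as suggested in the comment: reads only host memory.
        return self.latest_update_step > 0
```

Both guards return the same answer in this toy, but only the device-side check touches the simulated device value.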

Development

Successfully merging this pull request may close these issues.

EMA validation swap models before first update_parameters

3 participants